#%%
import gurobipy as gp
from gurobipy import GRB
from gurobipy import quicksum as qsum
import numpy as np, random


# %%
def solve_inverse_opt_strict_batch(
    observations, 
    cp, cu, ch, Ts,
    v_limits, w_limits,
    obstacles,
    K
):
    """
    Fit hinge-loss parameters (alpha, beta) from a batch of observed controls.

    One Gurobi model is built for the whole batch. ``alpha`` and ``beta`` are
    shared by all observations; every observation gets its own multipliers
    and its own stationarity slacks. Stationarity in v and w is relaxed to
    ``|gradient| <= slack`` and the slacks are penalised in the objective,
    while the hinge complementarity conditions are imposed exactly.

    Args:
        observations (list of dicts): Each dict must provide the keys
            'v_obs', 'w_obs', 'p_t' (2 entries), 'theta_t' and
            'p_hat' (2 entries).
        cp, cu, ch: Scalar weights multiplying the position, control and
            heading gradient terms of the forward cost.
        Ts: Time step used to propagate the observed state.
        v_limits, w_limits: Indexable (min, max) control bounds.
        obstacles (list of dicts): Shared by every observation; each dict
            must provide 'p_obs', 'r', 'robot_r' and 'd_safe'.
        K (int): Number of hinge-loss terms.

    Returns:
        ``(alpha.X, beta.X)`` when Gurobi reports ``GRB.OPTIMAL``, otherwise
        ``(None, None)``. On ``GRB.INFEASIBLE`` an IIS is also written to
        ``model_strict.ilp``.
    """
    slack_weight = 30   # objective penalty on the stationarity slacks
    epsilon = 1e-6      # a bound counts as inactive when violated by < -epsilon

    num_observations = len(observations)
    num_obstacles = len(obstacles)

    m = gp.Model("InverseOpt_StrictBatch")
    # Products of variables (gamma * alpha, ...) make the model non-convex.
    m.setParam("NonConvex", 2)

    # --- 1. Define GLOBAL Variables (the parameters we want to find) ---
    alpha = m.addMVar((K, 2), lb=-GRB.INFINITY, name="alpha")
    beta = m.addMVar(K, lb=-GRB.INFINITY, name="beta")

    # --- 2. Define LOCAL Variables (one set for each observation) ---
    z_star = m.addMVar((num_observations, K), lb=0, name="z_star")
    lambda_v_min = m.addMVar(num_observations, lb=0, name="lambda_v_min")
    lambda_v_max = m.addMVar(num_observations, lb=0, name="lambda_v_max")
    lambda_w_min = m.addMVar(num_observations, lb=0, name="lambda_w_min")
    lambda_w_max = m.addMVar(num_observations, lb=0, name="lambda_w_max")
    gamma = m.addMVar((num_observations, K), lb=0, name="gamma")
    delta = m.addMVar((num_observations, K), lb=0, name="delta")
    slack_v = m.addMVar(num_observations, lb=0, name="slack_v")
    slack_w = m.addMVar(num_observations, lb=0, name="slack_w")

    # One obstacle multiplier per (observation, obstacle) pair. Skip the
    # allocation entirely when there are no obstacles.
    mu_vars = None
    if num_obstacles > 0:
        mu_vars = m.addMVar((num_observations, num_obstacles), lb=0, name="mu")

    # --- 3. Set Objective Function ---
    alpha_norm_sq = gp.quicksum(alpha[k, j]**2 for k in range(K) for j in range(2))
    beta_norm_sq = gp.quicksum(beta[k]**2 for k in range(K))
    slack_sum = gp.quicksum(slack_v) + gp.quicksum(slack_w)
    m.setObjective(alpha_norm_sq + beta_norm_sq + slack_weight * slack_sum, GRB.MINIMIZE)

    # --- 4. Add KKT Constraints for EACH observation ---
    for i, obs in enumerate(observations):
        # Unpack data for observation i (tuples/lists are accepted for points)
        v_t_obs, w_t_obs = obs['v_obs'], obs['w_obs']
        theta_t = obs['theta_t']
        p_t = np.asarray(obs['p_t'], dtype=float)
        p_hat = np.asarray(obs['p_hat'], dtype=float)

        # Pre-compute constants for this observation
        d_t = np.array([np.cos(theta_t), np.sin(theta_t)])
        p_t1_obs = p_t + Ts * v_t_obs * d_t
        theta_t1_obs = theta_t + Ts * w_t_obs
        theta_hat = np.arctan2(p_hat[1] - p_t1_obs[1], p_hat[0] - p_t1_obs[0])
        # Offset from every obstacle centre to the propagated position.
        rel_to_obstacles = [p_t1_obs - obstacle['p_obs'] for obstacle in obstacles]

        # == Forward constraints on z_star for observation i ==
        for k in range(K):
            m.addConstr(z_star[i, k] >= alpha[k, 0] * v_t_obs + alpha[k, 1] * w_t_obs - beta[k], name=f"z_fwd_{i}_{k}")

        # == Stationarity Conditions for observation i ==
        grad_v_const = float(2 * cp * (p_t1_obs - p_hat) @ (Ts * d_t) + 2 * cu * v_t_obs)
        grad_v_orig = grad_v_const - lambda_v_min[i] + lambda_v_max[i]
        for j, rel in enumerate(rel_to_obstacles):
            # Scalar data coefficient first, so the variable enters linearly.
            grad_v_orig += -2 * float(rel @ (Ts * d_t)) * mu_vars[i, j]

        new_term_w = 2 * ch * (theta_t1_obs - theta_hat) * Ts
        grad_w_orig = 2 * cu * w_t_obs - lambda_w_min[i] + lambda_w_max[i] + new_term_w

        grad_v_full_expr = grad_v_orig + gp.quicksum(gamma[i, k] * alpha[k, 0] for k in range(K))
        grad_w_full_expr = grad_w_orig + gp.quicksum(gamma[i, k] * alpha[k, 1] for k in range(K))

        # Instead of == 0, we constrain the expression to be within [-slack, +slack]
        m.addConstr(grad_v_full_expr <= slack_v[i], name=f"stat_v_upper_{i}")
        m.addConstr(grad_v_full_expr >= -slack_v[i], name=f"stat_v_lower_{i}")
        m.addConstr(grad_w_full_expr <= slack_w[i], name=f"stat_w_upper_{i}")
        m.addConstr(grad_w_full_expr >= -slack_w[i], name=f"stat_w_lower_{i}")

        for k in range(K):
            m.addConstr(gamma[i, k] + delta[i, k] == 1, name=f"stat_z_{i}_{k}")

        # == Complementary Slackness for observation i ==
        # The observed controls are data, so an inactive bound simply forces
        # its multiplier to zero. Names keep a written IIS readable.
        if v_limits[0] - v_t_obs < -epsilon:
            m.addConstr(lambda_v_min[i] == 0, name=f"cs_v_min_{i}")
        if v_t_obs - v_limits[1] < -epsilon:
            m.addConstr(lambda_v_max[i] == 0, name=f"cs_v_max_{i}")
        if w_limits[0] - w_t_obs < -epsilon:
            m.addConstr(lambda_w_min[i] == 0, name=f"cs_w_min_{i}")
        if w_t_obs - w_limits[1] < -epsilon:
            m.addConstr(lambda_w_max[i] == 0, name=f"cs_w_max_{i}")

        for j, (obstacle, rel) in enumerate(zip(obstacles, rel_to_obstacles)):
            clearance = obstacle['r'] + obstacle['robot_r'] + obstacle['d_safe']
            gap = clearance**2 - np.sum(np.square(rel))
            if gap < -epsilon:
                m.addConstr(mu_vars[i, j] == 0, name=f"cs_mu_{i}_{j}")

        for k in range(K):
            expr = z_star[i, k] - (alpha[k, 0] * v_t_obs + alpha[k, 1] * w_t_obs - beta[k])
            m.addConstr(gamma[i, k] * expr == 0, name=f"cs_gamma_{i}_{k}")
            m.addConstr(delta[i, k] * z_star[i, k] == 0, name=f"cs_delta_{i}_{k}")

    # --- 5. Solve the model ---
    m.optimize()

    if m.Status == GRB.OPTIMAL:
        print("Successfully solved the strict batch inverse optimization problem.")
        return alpha.X, beta.X

    print(f"Inverse optimization failed with status code: {m.Status}")
    if m.Status == GRB.INFEASIBLE:
        print("Model is infeasible. This is common with strict batch models.")
        m.computeIIS()
        m.write("model_strict.ilp")
        print("IIS written to model_strict.ilp")
    return None, None
    
def solve_inverse_opt_strict_batch_v2(
    observations, 
    cp, cu, ch, Ts,
    v_limits, w_limits,
    obstacles,
    K
):
    """
    Fit hinge-loss parameters (alpha, beta) with exact stationarity and an
    aggregated, slack-relaxed complementary-slackness condition.

    For each observation the v and w stationarity conditions are imposed as
    equalities. All complementarity products are summed into one expression
    and tied to a single non-negative slack per observation, which is
    penalised in the objective. ``alpha`` and ``beta`` are boxed to
    ``[-50, 50]`` and the solve is limited to 300 seconds.

    Args:
        observations (list of dicts): Each dict must provide the keys
            'v_obs', 'w_obs', 'p_t' (2 entries), 'theta_t' and
            'p_hat' (2 entries).
        cp, cu, ch: Scalar weights multiplying the position, control and
            heading gradient terms of the forward cost.
        Ts: Time step used to propagate the observed state.
        v_limits, w_limits: Indexable (min, max) control bounds.
        obstacles (list of dicts): Shared by every observation; each dict
            must provide 'p_obs', 'r', 'robot_r' and 'd_safe'.
        K (int): Number of hinge-loss terms.

    Returns:
        ``(alpha.X, beta.X)`` for the best solution found (it need not be
        proven optimal if the time limit was hit), or ``(None, None)`` when
        the solver finished without any feasible solution.
    """
    bound = 50          # box bound on alpha and beta
    slack_weight = 30   # objective penalty on the complementarity slack
    time_limit = 300    # seconds

    num_observations = len(observations)
    num_obstacles = len(obstacles)

    m = gp.Model("InverseOpt_StrictBatch")
    # Products of variables (gamma * alpha, ...) make the model non-convex.
    m.setParam("NonConvex", 2)
    m.setParam("TimeLimit", time_limit)

    # --- 1. Define GLOBAL Variables (the parameters we want to find) ---
    alpha = m.addMVar((K, 2), lb=-bound, ub=bound, name="alpha")
    beta = m.addMVar(K, lb=-bound, ub=bound, name="beta")

    # --- 2. Define LOCAL Variables (one set for each observation) ---
    z_star = m.addMVar((num_observations, K), lb=0, name="z_star")
    lambda_v_min = m.addMVar(num_observations, lb=0, name="lambda_v_min")
    lambda_v_max = m.addMVar(num_observations, lb=0, name="lambda_v_max")
    lambda_w_min = m.addMVar(num_observations, lb=0, name="lambda_w_min")
    lambda_w_max = m.addMVar(num_observations, lb=0, name="lambda_w_max")
    gamma = m.addMVar((num_observations, K), lb=0, name="gamma")
    delta = m.addMVar((num_observations, K), lb=0, name="delta")
    slack = m.addMVar(num_observations, lb=0, name="slack")

    # One obstacle multiplier per (observation, obstacle) pair. Skip the
    # allocation entirely when there are no obstacles.
    mu_vars = None
    if num_obstacles > 0:
        mu_vars = m.addMVar((num_observations, num_obstacles), lb=0, name="mu")

    # --- 3. Set Objective Function ---
    alpha_norm_sq = gp.quicksum(alpha[k, j]**2 for k in range(K) for j in range(2))
    beta_norm_sq = gp.quicksum(beta[k]**2 for k in range(K))
    slack_sum = gp.quicksum(slack)
    m.setObjective(alpha_norm_sq + beta_norm_sq + slack_weight * slack_sum, GRB.MINIMIZE)

    # --- 4. Add KKT Constraints for EACH observation ---
    for i, obs in enumerate(observations):
        # Unpack data for observation i (tuples/lists are accepted for points)
        v_t_obs, w_t_obs = obs['v_obs'], obs['w_obs']
        theta_t = obs['theta_t']
        p_t = np.asarray(obs['p_t'], dtype=float)
        p_hat = np.asarray(obs['p_hat'], dtype=float)

        # Pre-compute constants for this observation
        d_t = np.array([np.cos(theta_t), np.sin(theta_t)])
        p_t1_obs = p_t + Ts * v_t_obs * d_t
        theta_t1_obs = theta_t + Ts * w_t_obs
        theta_hat = np.arctan2(p_hat[1] - p_t1_obs[1], p_hat[0] - p_t1_obs[0])
        # Offset from every obstacle centre to the propagated position.
        rel_to_obstacles = [p_t1_obs - obstacle['p_obs'] for obstacle in obstacles]

        # == Forward constraints on z_star for observation i ==
        for k in range(K):
            m.addConstr(z_star[i, k] >= alpha[k, 0] * v_t_obs + alpha[k, 1] * w_t_obs - beta[k], name=f"z_fwd_{i}_{k}")

        # == Stationarity Conditions for observation i ==
        grad_v_const = float(2 * cp * (p_t1_obs - p_hat) @ (Ts * d_t) + 2 * cu * v_t_obs)
        grad_v_orig = grad_v_const - lambda_v_min[i] + lambda_v_max[i]
        for j, rel in enumerate(rel_to_obstacles):
            # Scalar data coefficient first, so the variable enters linearly.
            grad_v_orig += -2 * float(rel @ (Ts * d_t)) * mu_vars[i, j]

        new_term_w = 2 * ch * (theta_t1_obs - theta_hat) * Ts
        grad_w_orig = 2 * cu * w_t_obs - lambda_w_min[i] + lambda_w_max[i] + new_term_w

        m.addConstr(grad_v_orig + gp.quicksum(gamma[i, k] * alpha[k, 0] for k in range(K)) == 0, name=f"stat_v_{i}")
        m.addConstr(grad_w_orig + gp.quicksum(gamma[i, k] * alpha[k, 1] for k in range(K)) == 0, name=f"stat_w_{i}")

        for k in range(K):
            m.addConstr(gamma[i, k] + delta[i, k] == 1, name=f"stat_z_{i}_{k}")

        # == Complementary Slackness for observation i ==
        # Each product below is (constraint value) * (multiplier). All of
        # them are summed and balanced by slack[i] >= 0.
        exp1 = (v_limits[0] - v_t_obs)*lambda_v_min[i] + (v_t_obs - v_limits[1])*lambda_v_max[i] + (w_limits[0] - w_t_obs)*lambda_w_min[i] + (w_t_obs - w_limits[1])*lambda_w_max[i]

        exp2 = 0
        for j, (obstacle, rel) in enumerate(zip(obstacles, rel_to_obstacles)):
            clearance = obstacle['r'] + obstacle['robot_r'] + obstacle['d_safe']
            gap = float(clearance**2 - np.sum(np.square(rel)))
            exp2 += mu_vars[i, j] * gap

        exp3 = 0
        for k in range(K):
            expr = (alpha[k, 0] * v_t_obs + alpha[k, 1] * w_t_obs - beta[k]) - z_star[i, k]
            exp3 += expr*gamma[i, k] - z_star[i, k]*delta[i, k]

        m.addConstr(exp1 + exp2 + exp3 + slack[i] == 0, name=f"cs_slack_{i}")

    # --- 5. Solve the model ---
    m.optimize()

    # A time-limited run may stop with a usable incumbent that is not proven
    # optimal, so accept any solution. Reading .X with no solution raises.
    if m.SolCount == 0:
        print(f"Inverse optimization failed with status code: {m.Status}")
        return None, None
    return alpha.X, beta.X

    
def solve_inverse_opt_strict_batch_v3(
    observations, 
    cp, cu, ch, Ts,
    v_limits, w_limits,
    obstacles,
    K
):
    """
    Fit hinge-loss parameters (alpha, beta) where the hinge acts on a
    position-error variable instead of on the controls.

    For each observation a variable ``z[i]`` is bounded below by the
    absolute position error ``|p_{t+1} - p_hat|`` (element-wise), and the
    k-th hinge argument is ``alpha[k] . z[i] - beta[k]``. Stationarity in v
    and w is imposed as equalities. Complementarity products are summed and
    tied to one non-negative slack per observation, which is penalised in
    the objective. ``alpha`` and ``beta`` are boxed to ``[-10, 10]`` and
    the solve is limited to 300 seconds.

    Args:
        observations (list of dicts): Each dict must provide the keys
            'v_obs', 'w_obs', 'p_t' (2 entries), 'theta_t' and
            'p_hat' (2 entries).
        cp: Accepted for signature parity with the other variants; not used
            by this formulation.
        cu, ch: Scalar weights multiplying the control and heading gradient
            terms of the forward cost.
        Ts: Time step used to propagate the observed state.
        v_limits, w_limits: Indexable (min, max) control bounds.
        obstacles (list of dicts): Shared by every observation; each dict
            must provide 'p_obs', 'r', 'robot_r' and 'd_safe'.
        K (int): Number of hinge-loss terms.

    Returns:
        ``(alpha.X, beta.X)`` for the best solution found (it need not be
        proven optimal if the time limit was hit), or ``(None, None)`` when
        the solver finished without any feasible solution.
    """
    bound = 10          # box bound on alpha and beta
    slack_weight = 30   # objective penalty on the complementarity slack
    time_limit = 300    # seconds

    num_observations = len(observations)
    num_obstacles = len(obstacles)

    m = gp.Model("InverseOpt_StrictBatch")
    # Products of variables (alpha * z, gamma * alpha, ...) are non-convex.
    m.setParam("NonConvex", 2)
    m.setParam("TimeLimit", time_limit)

    # --- 1. Define GLOBAL Variables (the parameters we want to find) ---
    alpha = m.addMVar((K, 2), lb=-bound, ub=bound, name="alpha")
    beta = m.addMVar(K, lb=-bound, ub=bound, name="beta")

    # --- 2. Define LOCAL Variables (one set for each observation) ---
    z_star = m.addMVar((num_observations, K), lb=0, name="z_star")
    lambda_v_min = m.addMVar(num_observations, lb=0, name="lambda_v_min")
    lambda_v_max = m.addMVar(num_observations, lb=0, name="lambda_v_max")
    lambda_w_min = m.addMVar(num_observations, lb=0, name="lambda_w_min")
    lambda_w_max = m.addMVar(num_observations, lb=0, name="lambda_w_max")
    gamma = m.addMVar((num_observations, K), lb=0, name="gamma")
    delta = m.addMVar((num_observations, K), lb=0, name="delta")
    slack = m.addMVar(num_observations, lb=0, name="slack")
    z = m.addMVar((num_observations, 2), lb=0, name="z")  # position hinge variable
    # hinge_arg[i, k] stands for alpha[k] . z[i] - beta[k]. That expression
    # is already bilinear, so multiplying it by gamma directly would give a
    # product of three variables, which a quadratic constraint cannot hold.
    hinge_arg = m.addMVar((num_observations, K), lb=-GRB.INFINITY, name="hinge_arg")

    # One obstacle multiplier per (observation, obstacle) pair. Skip the
    # allocation entirely when there are no obstacles.
    mu_vars = None
    if num_obstacles > 0:
        mu_vars = m.addMVar((num_observations, num_obstacles), lb=0, name="mu")

    # --- 3. Set Objective Function ---
    alpha_norm_sq = gp.quicksum(alpha[k, j]**2 for k in range(K) for j in range(2))
    beta_norm_sq = gp.quicksum(beta[k]**2 for k in range(K))
    slack_sum = gp.quicksum(slack)
    m.setObjective(alpha_norm_sq + beta_norm_sq + slack_weight * slack_sum, GRB.MINIMIZE)

    # --- 4. Add KKT Constraints for EACH observation ---
    for i, obs in enumerate(observations):
        # Unpack data for observation i (tuples/lists are accepted for points)
        v_t_obs, w_t_obs = obs['v_obs'], obs['w_obs']
        theta_t = obs['theta_t']
        p_t = np.asarray(obs['p_t'], dtype=float)
        p_hat = np.asarray(obs['p_hat'], dtype=float)

        # Pre-compute constants for this observation
        d_t = np.array([np.cos(theta_t), np.sin(theta_t)])
        p_t1_obs = p_t + Ts * v_t_obs * d_t
        theta_t1_obs = theta_t + Ts * w_t_obs
        theta_hat = np.arctan2(p_hat[1] - p_t1_obs[1], p_hat[0] - p_t1_obs[0])
        p_err_obs_vec = p_t1_obs - p_hat
        # Offset from every obstacle centre to the propagated position.
        rel_to_obstacles = [p_t1_obs - obstacle['p_obs'] for obstacle in obstacles]

        # z[i] >= |p_err| element-wise, written as two linear inequalities.
        m.addConstr(z[i, :] >= p_err_obs_vec, name=f"pos_hinge_pos_{i}")
        m.addConstr(z[i, :] >= -p_err_obs_vec, name=f"pos_hinge_neg_{i}")

        # == Forward constraints on z_star for observation i ==
        for k in range(K):
            m.addConstr(
                hinge_arg[i, k] == alpha[k, 0] * z[i, 0] + alpha[k, 1] * z[i, 1] - beta[k],
                name=f"hinge_arg_def_{i}_{k}",
            )
            m.addConstr(z_star[i, k] >= hinge_arg[i, k], name=f"z_fwd_{i}_{k}")

        # == Stationarity for v_t ==
        grad_v_from_cu = 2 * cu * v_t_obs - lambda_v_min[i] + lambda_v_max[i]
        grad_v_from_obs = gp.quicksum(
            # Scalar data coefficient first, so the variable enters linearly.
            -2 * float(rel @ (Ts * d_t)) * mu_vars[i, j]
            for j, rel in enumerate(rel_to_obstacles)
        )
        # Hinge contribution: sum_k gamma * alpha_k . (d p_{t+1} / d v).
        # NOTE(review): z bounds |p_err|, so d z / d v would normally carry
        # sign(p_err) per component; no sign is applied here - confirm intent.
        grad_v_from_pos_hinge = gp.quicksum(gamma[i, k] * (alpha[k, 0] * Ts * d_t[0] + alpha[k, 1] * Ts * d_t[1]) for k in range(K))
        grad_v_full = grad_v_from_cu + grad_v_from_obs + grad_v_from_pos_hinge

        # == Stationarity for w_t (no hinge term: the hinge depends on position only) ==
        new_term_w = 2 * ch * (theta_t1_obs - theta_hat) * Ts
        grad_w_orig = 2 * cu * w_t_obs - lambda_w_min[i] + lambda_w_max[i] + new_term_w

        m.addConstr(grad_v_full == 0, name=f"stat_v_{i}")
        m.addConstr(grad_w_orig == 0, name=f"stat_w_{i}")

        for k in range(K):
            m.addConstr(gamma[i, k] + delta[i, k] == 1, name=f"stat_z_{i}_{k}")

        # == Complementary Slackness for observation i ==
        # Each product below is (constraint value) * (multiplier). All of
        # them are summed and balanced by slack[i] >= 0.
        exp1 = (v_limits[0] - v_t_obs)*lambda_v_min[i] + (v_t_obs - v_limits[1])*lambda_v_max[i] + (w_limits[0] - w_t_obs)*lambda_w_min[i] + (w_t_obs - w_limits[1])*lambda_w_max[i]

        exp2 = 0
        for j, (obstacle, rel) in enumerate(zip(obstacles, rel_to_obstacles)):
            clearance = obstacle['r'] + obstacle['robot_r'] + obstacle['d_safe']
            gap = float(clearance**2 - np.sum(np.square(rel)))
            exp2 += mu_vars[i, j] * gap

        exp3 = 0
        for k in range(K):
            expr = hinge_arg[i, k] - z_star[i, k]
            exp3 += expr*gamma[i, k] - z_star[i, k]*delta[i, k]

        m.addConstr(exp1 + exp2 + exp3 + slack[i] == 0, name=f"cs_slack_{i}")

    # --- 5. Solve the model ---
    m.optimize()

    # A time-limited run may stop with a usable incumbent that is not proven
    # optimal, so accept any solution. Reading .X with no solution raises.
    if m.SolCount == 0:
        print(f"Inverse optimization failed with status code: {m.Status}")
        return None, None
    return alpha.X, beta.X

#%%
# Recorded data loaded from the working directory. Below, columns 0 and 1 of
# `traj` are used as the position and column 2 as the heading; columns 0 and
# 1 of `u_mpc` are used as the observed v and w controls.
traj = np.load('traj.npy')
u_mpc = np.load('u_hist.npy')

# %%
# --- Example Usage ---
# Defined up front so the cells below still run when the guarded solve is
# skipped (module imported) or when the solver returns no solution.
alpha_sol, beta_sol = None, None

if __name__ == '__main__':
    # One observation per consecutive pair of trajectory samples: the state
    # at step i, the control applied there, and the next position as p_hat.
    observations = [
        {
            'v_obs': u_mpc[i, 0],
            'w_obs': u_mpc[i, 1],
            'p_t': (traj[i, 0], traj[i, 1]),
            'theta_t': traj[i, 2],
            'p_hat': (traj[i+1, 0], traj[i+1, 1])
        }
        for i in range(traj.shape[0] - 1)
    ]

    CP, CU, CH, TS = 1.0, 0.01, 0.5, 0.1
    V_LIMITS = (0.0, 1)
    W_LIMITS = (-1.2, 1.2)
    OBSTACLES = [{'p_obs': np.array([2, 1]), 'r':0.3, 'robot_r': 0.12, 'd_safe': 0.1}]

    # Number of hinge-loss terms to infer
    K_TERMS = 1

    alpha_sol, beta_sol = solve_inverse_opt_strict_batch(
        observations = observations, 
        cp=CP, cu=CU, ch=CH, Ts=TS,
        v_limits=V_LIMITS,
        w_limits=W_LIMITS,
        obstacles=OBSTACLES,
        K=K_TERMS
    )

    if alpha_sol is not None:
        print("\nSolved unknown parameters:")
        for k in range(K_TERMS):
            print(f"Term k={k+1}:")
            print(f"  alpha_{k+1} = ({alpha_sol[k, 0]:.4f}, {alpha_sol[k, 1]:.4f})")
            print(f"  beta_{k+1}  = {beta_sol[k]:.4f}")


# %%
# Only persist real solutions: saving None would write a pickled object
# array that np.load rejects by default.
if alpha_sol is not None:
    np.save('alpha.npy', alpha_sol)
    np.save('beta.npy', beta_sol)
else:
    print("No solution available; alpha.npy and beta.npy were not written.")

# %%
print(alpha_sol)
# %%
